{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Binary classification"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Classification is about predicting an outcome from a fixed list of classes. The prediction is a probability distribution that assigns a probability to each possible outcome.\n",
    "\n",
    "A labeled classification sample is made up of a bunch of features and a class. The class is a boolean in the case of binary classification. We'll use the phishing dataset as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.264755Z",
     "iopub.status.busy": "2023-12-04T17:47:13.264536Z",
     "iopub.status.idle": "2023-12-04T17:47:13.668821Z",
     "shell.execute_reply": "2023-12-04T17:47:13.668532Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Phishing websites.\n",
       "\n",
       "This dataset contains features from web pages that are classified as phishing or not.\n",
       "\n",
       "    Name  Phishing                                                          \n",
       "    Task  Binary classification                                             \n",
       " Samples  1,250                                                             \n",
       "Features  9                                                                 \n",
       "  Sparse  False                                                             \n",
       "    Path  /Users/max/projects/online-ml/river/river/datasets/phishing.csv.gz"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import datasets\n",
    "\n",
    "dataset = datasets.Phishing()\n",
    "dataset"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This dataset is a streaming dataset which can be looped over."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.683729Z",
     "iopub.status.busy": "2023-12-04T17:47:13.683574Z",
     "iopub.status.idle": "2023-12-04T17:47:13.695685Z",
     "shell.execute_reply": "2023-12-04T17:47:13.695411Z"
    }
   },
   "outputs": [],
   "source": [
    "for x, y in dataset:\n",
    "    pass"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the first sample."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.697178Z",
     "iopub.status.busy": "2023-12-04T17:47:13.697104Z",
     "iopub.status.idle": "2023-12-04T17:47:13.705637Z",
     "shell.execute_reply": "2023-12-04T17:47:13.705411Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'empty_server_form_handler': 0.0,\n",
       " 'popup_window': 0.0,\n",
       " 'https': 0.0,\n",
       " 'request_from_other_domain': 0.0,\n",
       " 'anchor_from_other_domain': 0.0,\n",
       " 'is_popular': 0.5,\n",
       " 'long_url': 1.0,\n",
       " 'age_of_domain': 1,\n",
       " 'ip_in_url': 1}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x, y = next(iter(dataset))\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.706912Z",
     "iopub.status.busy": "2023-12-04T17:47:13.706842Z",
     "iopub.status.idle": "2023-12-04T17:47:13.714796Z",
     "shell.execute_reply": "2023-12-04T17:47:13.714537Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A binary classifier's goal is to learn to predict a binary target `y` from some given features `x`. We'll try to do this with a logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.716274Z",
     "iopub.status.busy": "2023-12-04T17:47:13.716198Z",
     "iopub.status.idle": "2023-12-04T17:47:13.731956Z",
     "shell.execute_reply": "2023-12-04T17:47:13.731725Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{False: 0.5, True: 0.5}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import linear_model\n",
    "\n",
    "model = linear_model.LogisticRegression()\n",
    "model.predict_proba_one(x)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model hasn't been trained on any data, and therefore outputs a default probability of 50% for each class.\n",
    "\n",
    "The model can be trained on the sample, which will update the model's state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.733503Z",
     "iopub.status.busy": "2023-12-04T17:47:13.733414Z",
     "iopub.status.idle": "2023-12-04T17:47:13.741526Z",
     "shell.execute_reply": "2023-12-04T17:47:13.741290Z"
    }
   },
   "outputs": [],
   "source": [
    "model.learn_one(x, y)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we try to make a prediction on the same sample, we can see that the probabilities are different, because the model has learned something."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.742950Z",
     "iopub.status.busy": "2023-12-04T17:47:13.742870Z",
     "iopub.status.idle": "2023-12-04T17:47:13.751313Z",
     "shell.execute_reply": "2023-12-04T17:47:13.751031Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{False: 0.494687699901455, True: 0.505312300098545}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba_one(x)"
   ]
  },
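  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The updated probabilities can be reproduced by hand. The cell below is a minimal sketch in plain Python, assuming a single stochastic gradient descent step on the log loss, zero initial weights, and a learning rate of 0.01 for both the weights and the intercept. The `sigmoid` helper and the variable names are ours for illustration and are not part of River's API. Under these assumptions, the result closely matches the probabilities displayed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1 / (1 + math.exp(-z))\n",
    "\n",
    "# Features and label of the first sample, copied from the outputs above\n",
    "x = {\n",
    "    'empty_server_form_handler': 0.0,\n",
    "    'popup_window': 0.0,\n",
    "    'https': 0.0,\n",
    "    'request_from_other_domain': 0.0,\n",
    "    'anchor_from_other_domain': 0.0,\n",
    "    'is_popular': 0.5,\n",
    "    'long_url': 1.0,\n",
    "    'age_of_domain': 1,\n",
    "    'ip_in_url': 1,\n",
    "}\n",
    "y = True\n",
    "\n",
    "lr = 0.01  # assumed learning rate for the weights and the intercept\n",
    "weights = {name: 0.0 for name in x}  # assumed zero initialisation\n",
    "intercept = 0.0\n",
    "\n",
    "# Before learning, the raw score is 0, so the probability is sigmoid(0) = 0.5\n",
    "p_before = sigmoid(sum(weights[name] * x[name] for name in x) + intercept)\n",
    "\n",
    "# Gradient of the log loss with respect to the raw score\n",
    "gradient = p_before - float(y)\n",
    "\n",
    "# One gradient descent step on the weights and the intercept\n",
    "for name, value in x.items():\n",
    "    weights[name] -= lr * gradient * value\n",
    "intercept -= lr * gradient\n",
    "\n",
    "p_after = sigmoid(sum(weights[name] * x[name] for name in x) + intercept)\n",
    "{False: 1 - p_after, True: p_after}"
   ]
  },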
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that there is also a `predict_one` if you're only interested in the most likely class rather than the probability distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.753009Z",
     "iopub.status.busy": "2023-12-04T17:47:13.752897Z",
     "iopub.status.idle": "2023-12-04T17:47:13.761857Z",
     "shell.execute_reply": "2023-12-04T17:47:13.761633Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_one(x)"
   ]
  },
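  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a classifier that outputs probabilities, picking the most likely class amounts to taking the key with the highest probability. The cell below is a minimal sketch in plain Python that reuses the probabilities displayed earlier. The variable names are illustrative, and this is not necessarily how River implements `predict_one` internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probabilities copied from the earlier output\n",
    "y_proba = {False: 0.494687699901455, True: 0.505312300098545}\n",
    "\n",
    "# Most likely class: the key with the largest probability\n",
    "most_likely = max(y_proba, key=y_proba.get)\n",
    "most_likely"
   ]
  },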
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Typically, an online model makes a prediction, and then learns once the ground truth reveals itself. The prediction and the ground truth can be compared to measure the model's correctness. If you have a dataset available, you can loop over it, make a prediction, update the model, and compare the model's output with the ground truth. This is called progressive validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.763372Z",
     "iopub.status.busy": "2023-12-04T17:47:13.763296Z",
     "iopub.status.idle": "2023-12-04T17:47:13.794339Z",
     "shell.execute_reply": "2023-12-04T17:47:13.794075Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 89.36%"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import metrics\n",
    "\n",
    "model = linear_model.LogisticRegression()\n",
    "\n",
    "metric = metrics.ROCAUC()\n",
    "\n",
    "for x, y in dataset:\n",
    "    y_pred = model.predict_proba_one(x)\n",
    "    model.learn_one(x, y)\n",
    "    metric.update(y, y_pred)\n",
    "    \n",
    "metric"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is a common way to evaluate an online model. In fact, there is a dedicated `evaluate.progressive_val_score` function that does this for you."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.795893Z",
     "iopub.status.busy": "2023-12-04T17:47:13.795799Z",
     "iopub.status.idle": "2023-12-04T17:47:13.827938Z",
     "shell.execute_reply": "2023-12-04T17:47:13.827710Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 89.36%"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import evaluate\n",
    "\n",
    "model = linear_model.LogisticRegression()\n",
    "metric = metrics.ROCAUC()\n",
    "\n",
    "evaluate.progressive_val_score(dataset, model, metric)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common way to improve the performance of a logistic regression is to scale the data. This can be done by using a `preprocessing.StandardScaler`. In particular, we can define a pipeline to organise our model into a sequence of steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.829351Z",
     "iopub.status.busy": "2023-12-04T17:47:13.829280Z",
     "iopub.status.idle": "2023-12-04T17:47:13.843109Z",
     "shell.execute_reply": "2023-12-04T17:47:13.842876Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div><div class=\"river-component river-pipeline\"><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">StandardScaler</pre></summary><code class=\"river-estimator-params\">StandardScaler (\n",
       "  with_std=True\n",
       ")\n",
       "</code></details><details class=\"river-component river-estimator\"><summary class=\"river-summary\"><pre class=\"river-estimator-name\">LogisticRegression</pre></summary><code class=\"river-estimator-params\">LogisticRegression (\n",
       "  optimizer=SGD (\n",
       "    lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "  )\n",
       "  loss=Log (\n",
       "    weight_pos=1.\n",
       "    weight_neg=1.\n",
       "  )\n",
       "  l2=0.\n",
       "  l1=0.\n",
       "  intercept_init=0.\n",
       "  intercept_lr=Constant (\n",
       "    learning_rate=0.01\n",
       "  )\n",
       "  clip_gradient=1e+12\n",
       "  initializer=Zeros ()\n",
       ")\n",
       "</code></details></div><style scoped>\n",
       ".river-estimator {\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "    max-width: max-content;\n",
       "}\n",
       "\n",
       ".river-pipeline {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    background: linear-gradient(#000, #000) no-repeat center / 1.5px 100%;\n",
       "}\n",
       "\n",
       ".river-union {\n",
       "    display: flex;\n",
       "    flex-direction: row;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper {\n",
       "    display: flex;\n",
       "    flex-direction: column;\n",
       "    align-items: center;\n",
       "    justify-content: center;\n",
       "    padding: 1em;\n",
       "    border-style: solid;\n",
       "    background: white;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-estimator {\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       "/* Vertical spacing between steps */\n",
       "\n",
       ".river-component + .river-component {\n",
       "    margin-top: 2em;\n",
       "}\n",
       "\n",
       ".river-union > .river-estimator {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .river-component {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       ".river-union > .pipeline {\n",
       "    margin-top: 0;\n",
       "}\n",
       "\n",
       "/* Spacing within a union of estimators */\n",
       "\n",
       ".river-union > .river-component + .river-component {\n",
       "    margin-left: 1em;\n",
       "}\n",
       "\n",
       "/* Typography */\n",
       "\n",
       ".river-estimator-params {\n",
       "    display: block;\n",
       "    white-space: pre-wrap;\n",
       "    font-size: 110%;\n",
       "    margin-top: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator > .river-estimator-params,\n",
       ".river-wrapper > .river-details > river-estimator-params {\n",
       "    background-color: white !important;\n",
       "}\n",
       "\n",
       ".river-wrapper > .river-details {\n",
       "    margin-bottom: 1em;\n",
       "}\n",
       "\n",
       ".river-estimator-name {\n",
       "    display: inline;\n",
       "    margin: 0;\n",
       "    font-size: 110%;\n",
       "}\n",
       "\n",
       "/* Toggle */\n",
       "\n",
       ".river-summary {\n",
       "    display: flex;\n",
       "    align-items:center;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".river-summary > div {\n",
       "    width: 100%;\n",
       "}\n",
       "</style></div>"
      ],
      "text/plain": [
       "Pipeline (\n",
       "  StandardScaler (\n",
       "    with_std=True\n",
       "  ),\n",
       "  LogisticRegression (\n",
       "    optimizer=SGD (\n",
       "      lr=Constant (\n",
       "        learning_rate=0.01\n",
       "      )\n",
       "    )\n",
       "    loss=Log (\n",
       "      weight_pos=1.\n",
       "      weight_neg=1.\n",
       "    )\n",
       "    l2=0.\n",
       "    l1=0.\n",
       "    intercept_init=0.\n",
       "    intercept_lr=Constant (\n",
       "      learning_rate=0.01\n",
       "    )\n",
       "    clip_gradient=1e+12\n",
       "    initializer=Zeros ()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from river import compose\n",
    "from river import preprocessing\n",
    "\n",
    "model = compose.Pipeline(\n",
    "    preprocessing.StandardScaler(),\n",
    "    linear_model.LogisticRegression()\n",
    ")\n",
    "\n",
    "model"
   ]
  },
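  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the data arrives one sample at a time, a scaler cannot compute the mean and standard deviation up front and has to maintain them incrementally. The cell below is a minimal sketch of that idea using Welford's algorithm. `RunningScaler` is a hypothetical helper written for illustration, and River's `StandardScaler` may differ in its details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "class RunningScaler:\n",
    "    # Hypothetical helper: scales each feature with running statistics\n",
    "\n",
    "    def __init__(self):\n",
    "        self.counts = collections.Counter()\n",
    "        self.means = collections.defaultdict(float)\n",
    "        self.m2s = collections.defaultdict(float)  # sums of squared deviations\n",
    "\n",
    "    def learn_one(self, x):\n",
    "        # Welford's algorithm: update the mean and squared deviations in one pass\n",
    "        for name, value in x.items():\n",
    "            self.counts[name] += 1\n",
    "            delta = value - self.means[name]\n",
    "            self.means[name] += delta / self.counts[name]\n",
    "            self.m2s[name] += delta * (value - self.means[name])\n",
    "\n",
    "    def transform_one(self, x):\n",
    "        scaled = {}\n",
    "        for name, value in x.items():\n",
    "            n = self.counts[name]\n",
    "            variance = self.m2s[name] / n if n else 0.0\n",
    "            # Output 0 when the variance is zero to avoid dividing by zero\n",
    "            scaled[name] = (value - self.means[name]) / variance ** 0.5 if variance > 0 else 0.0\n",
    "        return scaled\n",
    "\n",
    "scaler = RunningScaler()\n",
    "for sample in [{'a': 1.0}, {'a': 2.0}, {'a': 3.0}]:\n",
    "    scaler.learn_one(sample)\n",
    "\n",
    "# Mean is 2 and population variance is 2/3 after these three samples\n",
    "scaler.transform_one({'a': 3.0})"
   ]
  },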
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-12-04T17:47:13.844434Z",
     "iopub.status.busy": "2023-12-04T17:47:13.844357Z",
     "iopub.status.idle": "2023-12-04T17:47:13.889821Z",
     "shell.execute_reply": "2023-12-04T17:47:13.889600Z"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ROCAUC: 95.07%"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric = metrics.ROCAUC()\n",
    "evaluate.progressive_val_score(dataset, model, metric)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('river')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "vscode": {
   "interpreter": {
    "hash": "e6e87bad9c8c768904c061eafcb4f6739260ff8bb57f302c215ab258ded773dc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
